{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Spam_identification_solution"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 作者：袁宵\n",
    "## 时间：2018/11/16"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 导入依赖库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.preprocessing.text import Tokenizer\n",
    "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Embedding, LSTM, Dense, Conv1D, GlobalMaxPool1D\n",
    "from tensorflow.keras.utils import plot_model\n",
    "import numpy as np\n",
    "\n",
    "import os\n",
    "import chardet\n",
    "import re\n",
    "import jieba"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 超参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 3\n",
    "batch_size = 64\n",
    "max_sequence_length = 150\n",
    "embedding_size = 16\n",
    "num_words = 2000\n",
    "learning_rate = 0.001\n",
    "\n",
    "# RNN 参数\n",
    "LSTM_unit = 32\n",
    "\n",
    "# CNN 参数\n",
    "filters_numbers = 32\n",
    "kernel_size = 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据预处理"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 查看文件编码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GB2312\n"
     ]
    }
   ],
   "source": [
    "a_file_path = 'data/spam/2829'\n",
    "\n",
    "def find_text_encode(a_file_path):\n",
    "    with open(a_file_path, 'rb') as f:\n",
    "        return chardet.detect(f.read())\n",
    "\n",
    "file_info = find_text_encode(a_file_path)\n",
    "\n",
    "file_encoding = file_info['encoding']\n",
    "print(file_encoding)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 清洗文件内容，获取以空格分隔的中文字符串数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 过滤非中文字符\n",
    "pattern = re.compile('[^\\u4e00-\\u9fa5]')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/spam\n"
     ]
    }
   ],
   "source": [
    "spam_filepath = os.path.join(\"data\", \"spam\")\n",
    "normal_filepath = os.path.join(\"data\", \"normal\")\n",
    "print(spam_filepath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_clean_sentence(filepath, file_encoding):\n",
    "    fail_file_names_list = []\n",
    "    data = []\n",
    "    filenames = os.listdir(filepath)\n",
    "    for filename in filenames:\n",
    "        file_chinese_content = []\n",
    "        try:\n",
    "            with open(os.path.join(filepath, filename), encoding=file_encoding) as f:\n",
    "                for line in f.readlines():\n",
    "                    line = pattern.sub(\"\", line)\n",
    "                    line_cut = jieba.cut(line)\n",
    "                    file_chinese_content.extend(list(line_cut))\n",
    "        except:\n",
    "            fail_file_names_list.append(filename)  \n",
    "        if len(file_chinese_content) > 0:\n",
    "            file_chinese_content_sequence = \" \".join(file_chinese_content)\n",
    "            data.append(file_chinese_content_sequence)\n",
    "    return data, fail_file_names_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 1.372 seconds.\n",
      "Prefix dict has been built succesfully.\n"
     ]
    }
   ],
   "source": [
    "spam_data, spam_fail_file_names_list = get_clean_sentence(spam_filepath, file_encoding)\n",
    "normal_data, normal_fail_file_names_list = get_clean_sentence(normal_filepath, file_encoding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7605 6782\n"
     ]
    }
   ],
   "source": [
    "print(len(spam_data), len(normal_data))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 创建字典"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "50916"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer = Tokenizer(num_words=num_words)\n",
    "tokenizer.fit_on_texts(spam_data + normal_data)\n",
    "\n",
    "len(tokenizer.index_word)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 序列化数值化 把中文转换成整数值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def serialization_numeralization(spam_data, normal_data):\n",
    "    x_train_spam = tokenizer.texts_to_sequences(spam_data)\n",
    "    x_train_normal = tokenizer.texts_to_sequences(normal_data)\n",
    "\n",
    "    x_train_spam_pad = pad_sequences(x_train_spam, max_sequence_length)\n",
    "    x_train_normal_pad = pad_sequences(x_train_normal, max_sequence_length)\n",
    "\n",
    "    x_train = []\n",
    "    y_train = []\n",
    "    for it in x_train_normal_pad:\n",
    "        x_train.append(it)\n",
    "        y_train.append(0)\n",
    "    for it in x_train_spam_pad:\n",
    "        x_train.append(it)\n",
    "        y_train.append(1)\n",
    "    return x_train, y_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "14387 14387\n"
     ]
    }
   ],
   "source": [
    "x_train, y_train = serialization_numeralization(spam_data, normal_data)\n",
    "print(len(x_train), len(y_train))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 设计模型 （以下模型二选一）"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CNN 模型 训练速度快"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Sequential()\n",
    "model.add(Embedding(input_dim=num_words, output_dim=embedding_size, input_length=max_sequence_length))\n",
    "model.add(Conv1D(filters=filters_numbers, kernel_size=kernel_size))\n",
    "model.add(GlobalMaxPool1D())\n",
    "model.add(Dense(16, activation='relu'))\n",
    "model.add(Dense(1, activation='sigmoid'))\n",
    "model.compile(optimizer='rmsprop',\n",
    "              loss='binary_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "plot_model(model, to_file=\"Spam_identification_CNN_model.png\", show_shapes=True)          "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LSTM 模型 训练速度慢"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "# 对于具有2个类的单输入模型（二进制分类）：\n",
    "\n",
    "model = Sequential()\n",
    "model.add(Embedding(input_dim=num_words, output_dim=embedding_size, input_length=max_sequence_length))\n",
    "model.add(LSTM(units=32, dropout=0.2, recurrent_dropout=0.2))\n",
    "model.add(Dense(16, activation='relu'))\n",
    "model.add(Dense(1, activation='sigmoid'))\n",
    "model.compile(optimizer='rmsprop',\n",
    "              loss='binary_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "plot_model(model, to_file=\"Spam_identification_LSTM_model.png\", show_shapes=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 准备数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train_np = np.array(x_train)\n",
    "y_train_np = np.array(y_train)\n",
    "y_train_np = y_train_np[:, np.newaxis]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/b418/anaconda3/envs/yuanxiao/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py:108: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
      "  \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11509 samples, validate on 2878 samples\n",
      "Epoch 1/3\n",
      "11509/11509 [==============================] - 16s 1ms/step - loss: 0.2808 - acc: 0.9239 - val_loss: 0.0787 - val_acc: 0.9746\n",
      "Epoch 2/3\n",
      "11509/11509 [==============================] - 15s 1ms/step - loss: 0.0382 - acc: 0.9882 - val_loss: 0.0612 - val_acc: 0.9809\n",
      "Epoch 3/3\n",
      "11509/11509 [==============================] - 15s 1ms/step - loss: 0.0214 - acc: 0.9943 - val_loss: 0.0435 - val_acc: 0.9861\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fbd9da01080>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(x_train_np, y_train_np,\n",
    "          batch_size=batch_size, epochs=epochs,\n",
    "          shuffle=True, validation_split=0.2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 测试模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_spam_filepath = os.path.join(\"data\", \"test\",\"spam\")\n",
    "test_normal_filepath = os.path.join(\"data\", \"test\",\"normal\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_spam_data, test_spam_fail_file_names_list = get_clean_sentence(test_spam_filepath, file_encoding)\n",
    "test_normal_data, test_normal_fail_file_names_list = get_clean_sentence(test_normal_filepath, file_encoding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "376 376\n"
     ]
    }
   ],
   "source": [
    "x_test, y_test = serialization_numeralization(test_spam_data, test_normal_data)\n",
    "print(len(x_test), len(y_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_test_np = np.array(x_test)\n",
    "y_test_np = np.array(y_test)\n",
    "y_test_np = y_test_np[:, np.newaxis]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "376/376 [==============================] - 0s 515us/step\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.007587105084081835, 1.0]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate(x_test_np, y_test_np)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
